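"""
Utilities for backtesting, e.g. a helper that converts a target position into orders,
so that a strategy can output position signals directly instead of computing trade signals itself.
"""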
from typing import Union
import time
import logging
from enum_var import Action


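def pctCheck(func):
    """Drop zero-volume orders from the result of ``func``.

    ``func`` returns None, a single order tuple (action, contract, vol),
    or a pair of such tuples (close first, then open).
    """
    def inner(*args, **kwargs):
        temp = func(*args, **kwargs)
        if not temp:
            return None
        if len(temp) == 2:
            # The pair is a tuple and cannot be popped, so keep the orders
            # whose volume is non-zero.
            orders = tuple(order for order in temp if order[2] != 0)
            if len(orders) == 2:
                return orders
            # One order left: return it bare; none left: return None.
            return orders[0] if orders else None
        if len(temp) == 3:
            return temp if temp[2] != 0 else None
        return None

    return inner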


@pctCheck
def pct2vol(contract: str, pct: float, pos: float, cash: float, price: float) -> Union[tuple, None]:
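    """
    Convert a target position ratio into market order(s).

    :param contract: contract symbol, e.g. 'SHFE.rb2010'
    :param pct: float, target position ratio; positive for long, negative for short, absolute values sum to at most 1
    :param pos: float, current position ratio; positive for long, negative for short, absolute values sum to at most 1
    :param cash: float, capital the ratios are applied to; must be positive
    :param price: float, price used to convert value into volume; must be positive
    :return: a tuple that can be passed directly to create a market order, e.g. ('SELL_SHORT', 'SHFE.rb2010', 20);
        None when no trade is needed; or a pair of such tuples, one closing and one opening
    """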
    assert cash > 0
    assert price > 0
    # Volume is always non-negative; the direction is carried by the Action.
    vol = int(abs(abs(pct) - abs(pos)) * cash / price)
    if pct > 0 and pos > 0:
        if abs(pct) > abs(pos):  # add to the long position
            return (Action.BUY, contract, vol)

        elif abs(pct) < abs(pos):  # reduce the long position
            return (Action.SELL, contract, vol)

        elif abs(pct) == abs(pos):  # no action
            return None

        else:
            raise Exception('error!1')

    elif pct < 0 and pos < 0:
        if abs(pct) > abs(pos):  # add to the short position
            return (Action.SELL_SHORT, contract, vol)

        elif abs(pct) < abs(pos):  # reduce the short position
            return (Action.BUY_TO_COVER, contract, vol)

        elif abs(pct) == abs(pos):  # no action
            return None

        else:
            raise Exception('error!2')

    elif pct > 0 and pos < 0:  # close the short first, then open the long
        return ((Action.BUY_TO_COVER, contract, int(abs(pos) * cash / price)), (Action.BUY, contract, int(abs(pct) * cash / price)))

    elif pct < 0 and pos > 0:
        return ((Action.SELL, contract, int(abs(pos) * cash / price)), (Action.SELL_SHORT, contract, int(abs(pct) * cash / price)))

    elif pct == 0 and pos > 0:
        return (Action.SELL, contract, int(abs(pos) * cash / price))

    elif pct == 0 and pos < 0:
        return (Action.BUY_TO_COVER, contract, int(abs(pos) * cash / price))

    elif pct < 0 and pos == 0:
        return (Action.SELL_SHORT, contract, int(abs(pct) * cash / price))

    elif pct > 0 and pos == 0:
        return (Action.BUY, contract, int(abs(pct) * cash / price))

    elif pct == 0 and pos == 0:
        return None

    else:
        raise Exception('error3')
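

# A minimal worked sketch of the sizing arithmetic that pct2vol relies on:
# a position ratio times the capital gives a notional value, and dividing by
# the price gives a volume, truncated toward zero by int().
# NOTE: `_example_volume` is a hypothetical helper added for illustration only;
# it is not called anywhere in this module.
def _example_volume(fraction: float, cash: float, price: float) -> int:
    """Volume for ``fraction`` of ``cash`` at ``price``, truncated toward zero."""
    return int(abs(fraction) * cash / price)


# Example: raising a long position from 25% to 50% of 100000 at a price of 2500
# trades the 25% difference: _example_volume(0.50 - 0.25, 100000, 2500) == 10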


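def time_it(func):
    """Log the duration of each call to ``func`` and the running total to the 'timer' logger."""
    # Look the logger up here so the decorator also works when this module is
    # imported; it is the same logger object that is configured under __main__.
    timer = logging.getLogger('timer')
    count = 0
    sumTime = 0

    def inner(*args, **kwargs):
        nonlocal sumTime, count
        # timer.debug(str(count) + 'start time')
        start = time.time()
        result = func(*args, **kwargs)
        end = time.time()
        delta = end - start
        timer.debug(str(delta))
        # timer.debug(str(count) + 'end time')
        sumTime += delta
        timer.debug('sum time:' + str(sumTime))
        count += 1
        return result  # pass the wrapped function's return value through

    return inner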


if __name__ == '__main__':
    timer = logging.getLogger('timer')
    fm = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    fh = logging.FileHandler('log.txt', mode='a')
    fh.setFormatter(fm)
    fh.setLevel(logging.DEBUG)

    timer.setLevel(logging.DEBUG)
    timer.addHandler(fh)
    timer.propagate = False